#!/usr/bin/python
# -*- encoding: utf-8 -*-

from logger import setup_logger
from model2 import BiSeNet
from face_dataset import FaceMask

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist

import os
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
from PIL import Image
import torchvision.transforms as transforms
import cv2

def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
    """Blend a colored class-parsing map onto an image.

    Args:
        im: source image (PIL Image or ndarray); assumed RGB — TODO confirm
            against callers (``evaluate`` passes a PIL image opened from disk).
        parsing_anno: 2-D integer array of per-pixel class labels.
        stride: scale factor applied to ``parsing_anno`` before blending.
        save_im: when True, write the blended result to ``save_path``.
        save_path: destination file for the blended image (JPEG quality 100).

    Returns:
        The blended BGR image as a uint8 ndarray.
    """
    # Palette for the parsing classes (24 entries).
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [255, 85, 255], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255]]

    im = np.array(im)
    vis_im = im.copy().astype(np.uint8)
    vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
    vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride,
                                  interpolation=cv2.INTER_NEAREST)
    # White canvas, built as uint8 directly (avoids a float buffer + later cast).
    vis_parsing_anno_color = np.full(
        (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3), 255, dtype=np.uint8)

    num_of_class = int(np.max(vis_parsing_anno))

    for pi in range(1, num_of_class + 1):
        index = np.where(vis_parsing_anno == pi)
        # Wrap the palette so a class id >= len(part_colors) no longer raises
        # IndexError (fix: the original indexed part_colors[pi] unguarded).
        vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi % len(part_colors)]

    # Blend: 40% original image (converted to BGR for OpenCV) + 60% color map.
    vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)

    if save_im:
        cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])

    return vis_im

def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth', export_onnx=False):
    """Run BiSeNet face parsing on every file in ``dspth`` and save overlays.

    Args:
        respth: output directory for visualization results (created if missing).
        dspth: directory of input images; every entry is treated as an image.
        cp: checkpoint filename, loaded from the hard-coded 'res/cp' directory.
        export_onnx: when True, also export the network to ``respth``/model.onnx.
    """
    if not os.path.exists(respth):
        os.makedirs(respth)

    n_classes = 2
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    save_pth = osp.join('res/cp', cp)
    net.load_state_dict(torch.load(save_pth, weights_only=True))
    net.eval()
    dummy_input = torch.randn(1, 3, 512, 512).cuda()

    # Optionally export the model to ONNX.
    if export_onnx:
        onnx_path = osp.join(respth, "model.onnx")  # ONNX output path
        torch.onnx.export(
            net, dummy_input, onnx_path,
            input_names=["input"],
            output_names=["output"],
            # Allow a dynamic batch dimension.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
            opset_version=11,
        )
        print(f"ONNX 模型已导出: {onnx_path}")
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    with torch.no_grad():
        for image_path in os.listdir(dspth):
            # Fix: force 3-channel RGB so grayscale/RGBA inputs don't crash or
            # mis-normalize against the 3-channel mean/std above.
            img = Image.open(osp.join(dspth, image_path)).convert('RGB')
            image = img.resize((512, 512), Image.BILINEAR)
            img = to_tensor(image)
            img = torch.unsqueeze(img, 0)
            img = img.cuda()
            # BiSeNet returns multiple outputs; [0] is the main prediction —
            # presumably (1, n_classes, H, W); verify against model2.BiSeNet.
            out = net(img)[0]
            parsing = out.squeeze(0).cpu().numpy().argmax(0)

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, image_path))


import onnxruntime as ort


def test_onnx(respth='./res/test_res', dspth='./data', model_path='/home/liutengyu/face-parsing.PyTorch/res/test_res/model.onnx'):
    """Run the exported ONNX model on every image in ``dspth``.

    Saves a binary (0/255) foreground mask per image into ``respth``,
    assuming the 2-class model where label 1 is foreground — TODO confirm.

    Args:
        respth: output directory for masks (created if missing).
        dspth: directory of input images.
        model_path: path to the exported ONNX model.

    Returns:
        The last mask produced, or None when ``dspth`` has no readable images.
    """
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
    # Fix: ensure the output directory exists — cv2.imwrite fails silently
    # when it does not (previously only evaluate() created it).
    os.makedirs(respth, exist_ok=True)
    net_faceseg = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
    input_name = net_faceseg.get_inputs()[0].name  # hoisted out of the loop
    parsing = None  # fix: avoid NameError at return when dspth is empty
    for image_path in os.listdir(dspth):
        cropped_face = cv2.imread(osp.join(dspth, image_path))
        if cropped_face is None:
            # Skip unreadable / non-image files instead of crashing in cvtColor.
            continue
        image_rgb = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
        img1 = cv2.resize(image_rgb, (512, 512), interpolation=cv2.INTER_LINEAR)
        # Same preprocessing as evaluate(): scale to [0,1], then normalize.
        img1 = img1.astype(np.float32)
        img1 /= 255.0
        img1 = (img1 - mean) / std
        # HWC -> NCHW.
        blob = np.expand_dims(np.transpose(img1, (2, 0, 1)), axis=0).astype(np.float32)
        out = net_faceseg.run(None, {input_name: blob})[0]
        parsing = out.squeeze(0).argmax(0).astype(np.uint8)
        # Resize the label map back to the source resolution (nearest keeps labels crisp).
        parsing = cv2.resize(parsing, (cropped_face.shape[1], cropped_face.shape[0]),
                             interpolation=cv2.INTER_NEAREST)
        parsing[parsing == 1] = 255
        save_path = osp.join(respth, image_path)
        cv2.imwrite(save_path, parsing)
    return parsing

if __name__ == "__main__":
    # setup_logger('./res')
    test_dir = '/home/liutengyu/face-parsing.PyTorch/test'
    # Run the PyTorch model first, then the exported ONNX model on the same inputs.
    evaluate(dspth=test_dir, cp='79999_iter.pth')
    test_onnx(dspth=test_dir)
